from get_closure import get_optimizer_closure
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch
import copy
import sys
sys.path.append('./')


def average(params_lst):
    """Return the element-wise mean of several equally long parameter lists.

    Args:
        params_lst: Non-empty sequence of sequences; entry ``k`` of every
            inner sequence is averaged with entry ``k`` of the others. The
            length of the first inner sequence decides how many entries are
            produced.

    Returns:
        A list whose ``k``-th item is ``sum(params[k] / len(params_lst))``
        over all inner sequences, computed with autograd disabled.
    """
    n_lists = len(params_lst)
    with torch.no_grad():
        # ``sum`` starts from 0 and adds left to right, which is the same
        # accumulation order as a running ``+=`` from an initial 0.
        return [
            sum(params[idx] / n_lists for params in params_lst)
            for idx in range(len(params_lst[0]))
        ]


def L2_distance(params_lst1, params_lst2):
    """Return the Euclidean distance between two lists of tensors.

    The lists are walked in lockstep; the squared norm of every pairwise
    difference is accumulated and the square root of the total is returned.
    For empty inputs the result is ``0.0``.
    """
    squared_total = sum(
        torch.norm(left - right) ** 2
        for left, right in zip(params_lst1, params_lst2)
    )
    return squared_total ** 0.5


class GradientCalculator(optim.Optimizer):
    """Optimizer subclass used to evaluate weight-decayed gradients.

    The instance registers ``model.parameters()`` as its single parameter
    group and keeps the closure returned by ``get_optimizer_closure(model)``.
    NOTE(review): that helper is defined outside this file; it is presumably
    responsible for the forward/backward pass that fills ``p.grad`` --
    confirm against its definition.
    """

    def __init__(self, model, weight_decay):
        """Register the model's parameters and remember the L2 coefficient.

        Args:
            model: Object exposing ``parameters()``; also handed to
                ``get_optimizer_closure``.
            weight_decay: Factor applied to a parameter before it is added
                to that parameter's gradient.
        """
        super(GradientCalculator, self).__init__(
            params=model.parameters(), defaults={})
        self._model = model
        self._closure = get_optimizer_closure(model)
        self._weight_decay = weight_decay

    def _add_l2(self, grad, p):
        """Add ``weight_decay * p`` to ``grad`` in place and return ``grad``."""
        return grad.add_(p, alpha=self._weight_decay)

    def _compute_stochastic_grad(self, input, target):
        """Run the closure on one batch and collect per-parameter gradients.

        Args:
            input: Batch inputs; moved to CUDA via ``.cuda()``.
            target: Batch targets; moved to CUDA via ``.cuda()``.

        Returns:
            A list with exactly one entry per parameter of the first
            parameter group. For a parameter that has a gradient the entry
            is that parameter's own ``.grad`` tensor (not a copy) after the
            weight-decay term has been added to it in place. For a parameter
            without a gradient the entry is a zero tensor shaped like the
            parameter.
        """
        input, target = input.cuda(), target.cuda()
        self._closure(input, target)
        grads = []
        # The weight-decay add is bookkeeping on already computed gradients,
        # so it is kept out of the autograd graph.
        with torch.no_grad():
            for p in self.param_groups[0]["params"]:
                if p.grad is None:
                    # Exactly one placeholder keeps the list aligned with the
                    # parameter list; a tensor (rather than the scalar 0) is
                    # used because callers access ``.size()`` and ``.data``
                    # on every entry.
                    grads.append(torch.zeros_like(p))
                else:
                    grads.append(self._add_l2(grad=p.grad, p=p))
        return grads

    def copy_params(self, group):
        """Overwrite this instance's parameters with copies from ``group``.

        Args:
            group: Mapping with a ``"params"`` sequence; its entries are
                paired positionally with the parameters of this instance's
                first parameter group. Each source ``.data`` is deep-copied,
                so no storage is shared afterwards.
        """
        for (p_to, p_from) in zip(self.param_groups[0]["params"], group["params"]):
            p_to.data = copy.deepcopy(p_from.data)


class Client(GradientCalculator):
    """Gradient calculator that can also run local update steps.

    Subclasses supply the gradient estimator through
    ``_compute_grad_estimator``; this class provides the loop over
    mini-batches and the parameter update itself.
    """

    def __init__(self, model, eta, weight_decay, train_loader, n_iters, **kargs):
        """Store the local-training configuration.

        Args:
            model: Forwarded to ``GradientCalculator``.
            eta: Learning rate used by ``_update``.
            weight_decay: Forwarded to ``GradientCalculator``.
            train_loader: Iterable yielding ``(input, target)`` batches.
            n_iters: Number of local steps (K).
            **kargs: Accepted and ignored.
        """
        super(Client, self).__init__(model, weight_decay)
        self._eta = eta  # learning rate
        self._train_loader = train_loader
        self._n_iters = n_iters  # this is K

    def train(self, iters=None):
        """Perform ``iters`` local steps, restarting the loader when it runs out.

        Args:
            iters: Number of steps to take. When omitted, the ``n_iters``
                value given at construction is used.
        """
        if iters is None:
            iters = self._n_iters
        # The built-in iter()/next() are used instead of calling a ``.next()``
        # method, which is not part of the Python 3 iterator protocol.
        data_iterator = iter(self._train_loader)
        for _ in range(iters):  # Making K local steps
            try:
                input, target = next(data_iterator)
            except StopIteration:
                # Loader exhausted: start a fresh pass over the data.
                data_iterator = iter(self._train_loader)
                input, target = next(data_iterator)
            # compute the gradients for the local update
            grads = self._compute_grad_estimator(input, target)
            self._update(grads)

    def _compute_grad_estimator(self, input, target):
        """Return the per-parameter gradient estimate; defined by subclasses."""
        raise NotImplementedError

    def _update(self, grads, noise=0.0):
        """Apply one step ``p <- p - eta * grads[i] + noise * N(0, I)``.

        Args:
            grads: Sequence indexed like the first parameter group.
            noise: Scale of the Gaussian perturbation added after the step.

        Parameters whose ``.grad`` is ``None`` are left untouched. The
        perturbation tensor is created with ``.cuda()``, so parameters are
        expected to live on a CUDA device.
        """
        with torch.no_grad():
            for i, p in enumerate(self.param_groups[0]["params"]):
                if p.grad is None:
                    continue
                p.add_(grads[i], alpha=-self._eta)  # this is where the local update is made
                p.add_(torch.randn(p.size()).cuda(), alpha=noise)


class SGD_Client(Client):
    """Client whose local step uses the plain stochastic gradient.

    No extra instance state is needed; everything comes from the parent
    classes.
    """

    def _compute_grad_estimator(self, input, target):
        """Return the weight-decayed stochastic gradient for this batch."""
        plain_grads = self._compute_stochastic_grad(input, target)
        return plain_grads


class SVRG_Client(Client):
    """Client using the estimator ``g(x) - g(snapshot) + ref_grads``.

    Two extra ``GradientCalculator`` instances hold deep copies of the model:
    ``_snap_shot`` is used for the local steps and ``_initial_snap_shot`` for
    the large-batch estimate requested through
    ``compute_large_batch_vr_grad``.
    """

    def __init__(self, **kargs):
        """Build the client; ``model`` and ``weight_decay`` must be keywords."""
        super(SVRG_Client, self).__init__(**kargs)
        self._initial_snap_shot = GradientCalculator(model=copy.deepcopy(
            kargs["model"]), weight_decay=kargs["weight_decay"])
        self._snap_shot = GradientCalculator(model=copy.deepcopy(
            kargs["model"]), weight_decay=kargs["weight_decay"])
        self._ref_grads = None
        self._delta = 0.0  # scale of the Gaussian term added to the estimator

    def compute_full_grad(self):
        """Return the gradient over this client's whole dataset in one batch."""
        dataset = self._train_loader.dataset
        full_loader = torch.utils.data.DataLoader(
            dataset, batch_size=len(dataset), shuffle=False)
        # Built-in next(iter(...)) instead of a ``.next()`` method, which is
        # not part of the Python 3 iterator protocol.
        input, target = next(iter(full_loader))
        return self._compute_stochastic_grad(input, target)

    def compute_large_batch_vr_grad(self, ref_grads):
        """Return the estimator on one shuffled batch of ``K * b`` samples.

        Args:
            ref_grads: Reference gradients added to the gradient difference;
                they are passed in explicitly rather than read from
                ``self._ref_grads``.

        ``K`` is the number of local iterations and ``b`` the training
        loader's batch size; ``_initial_snap_shot`` supplies the previous
        gradients.
        """
        K = self._n_iters
        b = self._train_loader.batch_size
        large_loader = torch.utils.data.DataLoader(  # load large batch
            self._train_loader.dataset, batch_size=K*b, shuffle=True)
        input, target = next(iter(large_loader))
        return self._compute_general_grad_estimator(input, target,
                                                    self._initial_snap_shot,
                                                    ref_grads)

    def _compute_grad_estimator(self, input, target):
        """Estimator for a local step, using ``_snap_shot`` and stored refs."""
        return self._compute_general_grad_estimator(input=input, target=target,
                                                    snap_shot=self._snap_shot,
                                                    ref_grads=self._ref_grads)

    def _compute_general_grad_estimator(self, input, target, snap_shot, ref_grads):
        """Return ``current - previous + ref_grads (+ delta * noise)`` per parameter.

        Args:
            input: Batch inputs.
            target: Batch targets.
            snap_shot: ``GradientCalculator`` evaluated on the same batch to
                obtain the previous gradients.
            ref_grads: Reference gradients, indexed like the parameters.

        Returns:
            List of variance-reduced gradients, one per parameter.
        """
        # Check the argument that is actually indexed below; reference
        # gradients must exist before this is called.
        assert ref_grads is not None
        current_grads = self._compute_stochastic_grad(input, target)
        # it uses the snapshot model to get previous gradients
        prev_grads = snap_shot._compute_stochastic_grad(input, target)
        vr_grads = [current_grads[i] - prev_grads[i] + ref_grads[i]
                    + self._delta * torch.randn(prev_grads[i].size()).cuda()
                    for i in range(len(current_grads))]
        return vr_grads  # variance reduced gradients

    def set_ref_grads(self, grads):
        """Replace the stored reference gradients with deep copies of ``grads``."""
        self._ref_grads = []
        for grad in grads:
            self._ref_grads.append(copy.deepcopy(grad.data))


class SARAH_Client(SVRG_Client):
    """SVRG-style client that refreshes its snapshot and refs on every step."""

    def __init__(self, **kargs):
        """Forward all keyword arguments to ``SVRG_Client``."""
        super().__init__(**kargs)

    def _compute_grad_estimator(self, input, target):
        """Return the parent estimator, then roll the recursion forward.

        After the estimate is computed, ``_snap_shot`` receives the current
        parameters and the estimate itself becomes the new reference
        gradients.
        """
        estimate = super()._compute_grad_estimator(input, target)
        # saving the current parameters in the snapshot
        self._snap_shot.copy_params(self.param_groups[0])
        # setting the ref grads to be the variance reduced gradients
        self.set_ref_grads(estimate)
        return estimate


class STORM_Client(Client):
    """Client using ``g(x) + (1 - beta) * (ref_grads - g(snapshot))``.

    Like ``SVRG_Client`` it keeps two ``GradientCalculator`` deep copies of
    the model: ``_snap_shot`` for local steps and ``_initial_snap_shot`` for
    the large-batch estimate.
    """

    def __init__(self, **kargs):
        """Build the client; ``model`` and ``weight_decay`` must be keywords."""
        super(STORM_Client, self).__init__(**kargs)
        self._initial_snap_shot = GradientCalculator(model=copy.deepcopy(
            kargs["model"]), weight_decay=kargs["weight_decay"])
        self._snap_shot = GradientCalculator(model=copy.deepcopy(
            kargs["model"]), weight_decay=kargs["weight_decay"])
        self._ref_grads = None

    def _next_large_batch(self):
        """Draw one shuffled batch of ``K * b`` samples from the client data."""
        K = self._n_iters
        b = self._train_loader.batch_size
        large_loader = torch.utils.data.DataLoader(
            self._train_loader.dataset, batch_size=K*b, shuffle=True)
        # Built-in next(iter(...)) instead of a ``.next()`` method, which is
        # not part of the Python 3 iterator protocol.
        return next(iter(large_loader))

    def compute_initial_large_batch_grad(self):
        """Return the plain stochastic gradient on one large batch."""
        input, target = self._next_large_batch()
        current_grads = self._compute_stochastic_grad(input, target)
        # New list object; the entries themselves are not copied.
        return list(current_grads)

    def compute_large_batch_vr_grad(self, ref_grads, beta):
        """Return the estimator on one large batch using ``_initial_snap_shot``.

        Args:
            ref_grads: Reference gradients, indexed like the parameters.
            beta: Momentum parameter of the estimator.
        """
        input, target = self._next_large_batch()
        return self._compute_general_grad_estimator(input, target,
                                                    self._initial_snap_shot,
                                                    ref_grads, beta)

    def _compute_grad_estimator(self, input, target):
        """Estimator for a local step; also advances snapshot and refs.

        ``beta`` is 0 here because the local update uses no momentum. After
        the estimate is computed, ``_snap_shot`` receives the current
        parameters and the estimate becomes the new reference gradients.
        """
        vr_grads = self._compute_general_grad_estimator(input=input, target=target,
                                                        snap_shot=self._snap_shot,
                                                        ref_grads=self._ref_grads, beta=0)
        self._snap_shot.copy_params(self.param_groups[0])
        self.set_ref_grads(vr_grads)
        return vr_grads

    def _compute_general_grad_estimator(self, input, target, snap_shot, ref_grads, beta):
        """Return ``current + (1 - beta) * (ref_grads - previous)`` per parameter.

        Args:
            input: Batch inputs.
            target: Batch targets.
            snap_shot: ``GradientCalculator`` evaluated on the same batch to
                obtain the previous gradients.
            ref_grads: Reference gradients, indexed like the parameters.
            beta: Momentum parameter.
        """
        # Check the argument that is actually indexed below.
        assert ref_grads is not None
        current_grads = self._compute_stochastic_grad(input, target)
        prev_grads = snap_shot._compute_stochastic_grad(input, target)
        vr_grads = [current_grads[i] + (1 - beta) * (ref_grads[i] - prev_grads[i])
                    for i in range(len(current_grads))]
        return vr_grads

    def set_ref_grads(self, grads):
        """Replace the stored reference gradients with deep copies of ``grads``."""
        self._ref_grads = []
        for grad in grads:
            self._ref_grads.append(copy.deepcopy(grad.data))


class Server:
    """Base coordinator that owns one client optimizer per worker.

    Subclasses provide ``_get_optimizer`` (how a client is built) and
    ``_update`` (what one communication round does).
    """

    def __init__(self, model, eta, weight_decay, train_loaders,
                 n_local_iters, n_workers, **kargs):
        """Create ``n_workers`` clients, each with a deep copy of ``model``.

        Args:
            model: Server-side model.
            eta: Learning rate passed to every client.
            weight_decay: L2 coefficient passed to every client.
            train_loaders: Indexable collection; entry ``i`` goes to client ``i``.
            n_local_iters: Passed to every client as ``n_iters``.
            n_workers: Number of clients to create.
            **kargs: Accepted and ignored here.
        """
        self._model = model
        self._n_workers = n_workers
        clients = []
        for worker_id in range(self._n_workers):
            client = self._get_optimizer(
                model=copy.deepcopy(self._model),
                eta=eta,
                weight_decay=weight_decay,
                train_loader=train_loaders[worker_id],
                n_iters=n_local_iters,
            )
            clients.append(client)
        self._optimizers = clients
        self._update_count = 0

    def update(self):
        """Run one round via ``_update`` and count it."""
        self._update()
        self._update_count += 1

    def get_model(self):
        """Return the server's current model object."""
        return self._model

    def _update(self):
        """One communication round; must be overridden."""
        raise NotImplementedError

    def _get_optimizer(self, **kargs):
        """Build one client from keyword arguments; must be overridden."""
        raise NotImplementedError


class LocalVRSGD_Server(Server):
    """Server for variance-reduced clients that train one worker per round.

    Each round refreshes the shared reference gradients, lets one randomly
    chosen client run its local steps, and then copies that client's
    parameters to every client.
    """

    def __init__(self, **kargs):
        """Build the clients and derive how often a full gradient is taken.

        The interval is ``1 + n_local_data // (n_local_iters * batch_size)``,
        with data size and batch size read from the first training loader.
        """
        super().__init__(**kargs)
        first_loader = kargs["train_loaders"][0]
        batch_size = first_loader.batch_size
        n_local_data = len(first_loader.dataset)
        self.local_iters = kargs["n_local_iters"]
        samples_per_round = self.local_iters * batch_size
        self._global_grad_compute_intvl = 1 + n_local_data // samples_per_round
        self._ref_grads = None

    def _update(self, **kargs):
        """Run one round on a randomly selected worker."""
        # choose a random machine to make the update
        chosen = np.random.choice(self._n_workers)
        # refresh the averaged reference gradients on every client
        self._communicate_approx_global_grad()
        starting_group = self._optimizers[chosen].param_groups[0]
        for worker in self._optimizers:
            worker._initial_snap_shot.copy_params(starting_group)
        self._optimizers[chosen].train(iters=self.local_iters)
        # communicate the final parameters from the chosen worker to all
        self._communicate_params(chosen)
        self._model = copy.deepcopy(self._optimizers[0]._model)

    def _communicate_params(self, i):
        """Copy worker ``i``'s parameters into every client and its snapshot."""
        source_group = self._optimizers[i].param_groups[0]
        for worker in self._optimizers:
            worker.copy_params(source_group)
            worker._snap_shot.copy_params(source_group)

    def _communicate_approx_global_grad(self):
        """Average per-client gradients and install them as reference grads.

        Every ``_global_grad_compute_intvl`` rounds (starting with round 0)
        the clients contribute full gradients; in the other rounds they
        contribute large-batch variance-reduced estimates built on the
        previous reference gradients.
        """
        use_full_grad = (
            self._update_count % self._global_grad_compute_intvl == 0)
        if not use_full_grad:
            assert self._ref_grads is not None
            per_worker = [
                worker.compute_large_batch_vr_grad(ref_grads=self._ref_grads)
                for worker in self._optimizers
            ]
        else:
            per_worker = [worker.compute_full_grad()
                          for worker in self._optimizers]
        mean_grads = average(per_worker)
        self._ref_grads = mean_grads

        for worker in self._optimizers:
            worker.set_ref_grads(mean_grads)


class LocalSVRG_Server(LocalVRSGD_Server):
    """``LocalVRSGD_Server`` whose workers are ``SVRG_Client`` instances."""

    def _get_optimizer(self, **kargs):
        """Build one ``SVRG_Client`` from the given keyword arguments."""
        client = SVRG_Client(**kargs)
        return client


# BVR-L-SGD
class LocalSARAH_Server(LocalVRSGD_Server):
    """``LocalVRSGD_Server`` whose workers are ``SARAH_Client`` instances."""

    def _get_optimizer(self, **kargs):
        """Build one ``SARAH_Client`` from the given keyword arguments."""
        client = SARAH_Client(**kargs)
        return client

# OUR-FO


class LocalSTORM_Server(Server):
    """Server driving ``STORM_Client`` workers, one random worker per round."""

    def __init__(self, **kargs):
        """Build the clients; ``beta`` and ``n_local_iters`` must be keywords."""
        super(LocalSTORM_Server, self).__init__(**kargs)
        self._ref_grads = None
        self.beta = kargs["beta"]  # note that STORM has a momentum parameter
        self.local_iters = kargs["n_local_iters"]

    def _get_optimizer(self, **kargs):
        """Build one ``STORM_Client`` from the given keyword arguments."""
        return STORM_Client(**kargs)

    def _update(self, **kargs):
        """Run one round on a randomly selected worker.

        The very first round performs a single local step; every later round
        performs ``local_iters`` steps.
        """
        i = np.random.choice(self._n_workers)
        self._communicate_approx_global_grad(beta=self.beta)
        for optimizer in self._optimizers:
            optimizer._initial_snap_shot.copy_params(
                self._optimizers[i].param_groups[0])
        if self._update_count == 0:
            self._optimizers[i].train(iters=1)
        else:
            self._optimizers[i].train(iters=self.local_iters)
        self._communicate_params(i)
        self._model = copy.deepcopy(self._optimizers[0]._model)

    def _communicate_params(self, i):
        """Copy worker ``i``'s parameters into every client and its snapshot."""
        for optimizer in self._optimizers:
            optimizer.copy_params(self._optimizers[i].param_groups[0])
            optimizer._snap_shot.copy_params(
                self._optimizers[i].param_groups[0])

    def _communicate_approx_global_grad(self, beta):
        """Average per-client gradients and install them as reference grads.

        Args:
            beta: Momentum parameter forwarded to each client's large-batch
                estimator. It is used as given instead of being silently
                replaced by ``self.beta``.

        On the first round the clients contribute plain large-batch
        gradients; afterwards they contribute estimates built on the previous
        reference gradients.
        """
        if self._update_count == 0:
            params_lst = [optimizer.compute_initial_large_batch_grad()
                          for optimizer in self._optimizers]
        else:
            assert self._ref_grads is not None
            params_lst = [optimizer.compute_large_batch_vr_grad(ref_grads=self._ref_grads, beta=beta)
                          for optimizer in self._optimizers]
        averaged = average(params_lst)
        self._ref_grads = averaged

        for optimizer in self._optimizers:
            optimizer.set_ref_grads(averaged)


class LocalSGD_Server(Server):
    """Server where every worker trains locally and parameters are averaged."""

    def _update(self):
        """Let every client run its local steps, then average the parameters."""
        for optimizer in self._optimizers:
            # ``train`` needs the number of steps; each client stores the
            # number of local iterations it was constructed with.
            optimizer.train(iters=optimizer._n_iters)
        self._communicate_params()
        self._model = copy.deepcopy(self._optimizers[0]._model)

    def _get_optimizer(self, **kargs):
        """Build one ``SGD_Client`` from the given keyword arguments."""
        return SGD_Client(**kargs)

    def _communicate_params(self):
        """Replace every client's parameters with the across-client average."""
        params_lst = [optimizer.param_groups[0]["params"]
                      for optimizer in self._optimizers]
        averaged = average(params_lst)
        for optimizer in self._optimizers:
            optimizer.copy_params({"params": averaged})
